{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    },
    "colab": {
      "name": "example.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uhBk_wpdj6ep",
        "colab_type": "text"
      },
      "source": [
        "# Named Entity Recognition\n",
        "## Bidirectional-LSTM-CRF model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FTWPwDKej6eu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import os\n",
        "import tensorflow as tf\n",
        "import tensorflow_addons"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xSqtHz2dowmv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "8d2881d8-2d63-45af-d718-71dc5b2d3af7"
      },
      "source": [
        "# Reset the global Keras state before building the model.\n",
        "# The call parentheses are required: the bare attribute only displays the function object.\n",
        "tf.keras.backend.clear_session()"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_D9Qc2CLwK3m",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "13e608ad-ec36-46f1-8216-1f511500a451"
      },
      "source": [
        "# Clone directly into BiLSTM_CRF (underscore instead of hyphen) rather than cloning and renaming;\n",
        "# skip the clone when the directory already exists so that the cell can be re-run safely.\n",
        "![ -d BiLSTM_CRF ] || git clone https://github.com/quocdat32461997/BiLSTM-CRF.git BiLSTM_CRF"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GkPr6RWBj6fD",
        "colab_type": "text"
      },
      "source": [
        "### Download Entity-Annotated-Corpus\n",
        "\n",
        "#### Option 1: through Google Drive\n",
        "* Download corpus dataset from link: https://drive.google.com/file/d/1JZ4JXuJrEG1e9OiM1PEoRVtd9wcAIVz9/view?usp=sharing\n",
        "* Upload corpus dataset back to your Google Colab under **/content** directory\n",
        "\n",
        "\n",
        "#### Option 2: using Kaggle API \n",
        "* Generate a Kaggle API token and upload the resulting **kaggle.json** file to the **/content** directory of Google Colab\n",
        "* Create **~/.kaggle** if it does not exist and move **kaggle.json** to **~/.kaggle/kaggle.json** by command: **!mkdir -p ~/.kaggle && mv kaggle.json ~/.kaggle/kaggle.json**\n",
        "* Restrict the token file to your own user by command: **!chmod 600 ~/.kaggle/kaggle.json**\n",
        "* Download corpus dataset: **!kaggle datasets download -d abhinavwalia95/entity-annotated-corpus**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D3lD7iGej6fJ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 69
        },
        "outputId": "6c3ee224-ce2e-4e1f-e733-184e04f0da83"
      },
      "source": [
        "# Shell escapes (`!cmd`) never raise Python exceptions, so a try/except around them\n",
        "# cannot detect a missing token; check for the file explicitly instead.\n",
        "kaggle_token = os.path.expanduser('~/.kaggle/kaggle.json')\n",
        "if os.path.exists('kaggle.json'):\n",
        "    # install a freshly uploaded token; create ~/.kaggle first in case it does not exist\n",
        "    !mkdir -p ~/.kaggle\n",
        "    !mv kaggle.json ~/.kaggle/kaggle.json\n",
        "    !chmod 600 ~/.kaggle/kaggle.json\n",
        "if os.path.exists(kaggle_token):\n",
        "    !kaggle datasets download -d abhinavwalia95/entity-annotated-corpus --unzip --force\n",
        "else:\n",
        "    print(\"Please see Option 1 to get Entity-Annotated-Corpus\")"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading entity-annotated-corpus.zip to /content\n",
            " 64% 17.0M/26.4M [00:00<00:00, 18.4MB/s]\n",
            "100% 26.4M/26.4M [00:01<00:00, 26.0MB/s]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2pMpQr-Wj6fW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"\n",
        "utils.py - module to implement utils for BiLSTM-CRF\n",
        "\"\"\"\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import pandas as pd\n",
        "from sklearn.utils import shuffle\n",
        "from tensorflow.keras.utils import Sequence, to_categorical\n",
        "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
        "\n",
        "class SentenceGetter(Sequence):\n",
        "    \"\"\"\n",
        "    Keras Sequence that groups a token-level corpus into sentences and serves\n",
        "    them as batches of padded word indices and one-hot encoded tags\n",
        "    \"\"\"\n",
        "    def __init__(self, data, words, tags, maxlen, batch_size = 16, shuffle = False):\n",
        "        \"\"\"\n",
        "        __init__ - initializer for SentenceGetter class\n",
        "        Inputs:\n",
        "            - data : String or Pandas DataFrame object\n",
        "                Path string to a CSV file, or a dataframe with the columns\n",
        "                'Sentence #', 'Word', 'POS' and 'Tag'\n",
        "            - words : list or set\n",
        "                Distinct words; the i-th word is mapped to index i + 1\n",
        "                because 0 is used as the padding value\n",
        "            - tags : list or set\n",
        "                Distinct tags; the i-th tag is mapped to index i.\n",
        "                Must contain 'O', which is used to pad the labels\n",
        "            - maxlen : int\n",
        "                Length every sentence is padded or truncated to\n",
        "            - batch_size : int\n",
        "                Number of sentences per batch\n",
        "            - shuffle : bool\n",
        "                Whether to reshuffle the sentences before every epoch\n",
        "        Raises:\n",
        "            - TypeError if data is neither a path string nor a DataFrame\n",
        "        \"\"\"\n",
        "\n",
        "        if isinstance(data, str):\n",
        "            # load data from a CSV file path\n",
        "            data = pd.read_csv(data, encoding = 'latin1')\n",
        "        elif not isinstance(data, pd.DataFrame):\n",
        "            raise TypeError('Data is None or not found')\n",
        "\n",
        "        self.word2dix = {w : i + 1 for i, w in enumerate(words)}\n",
        "        self.tag2dix = {t : i for i, t in enumerate(tags)}\n",
        "        self.n_tags = len(tags)\n",
        "        self.batch_size = batch_size\n",
        "        self.shuffle = shuffle\n",
        "        self.maxlen = maxlen\n",
        "\n",
        "        # one entry per sentence, each holding a list of (word, POS, tag) tuples\n",
        "        grouped = data.groupby('Sentence #').apply(self.agg_func)\n",
        "        self.grouped = grouped.reset_index().rename(columns = {0 : 'sentence'})['sentence']\n",
        "        # snapshot of the sentences taken before any shuffling\n",
        "        self.sentences = list(self.grouped)\n",
        "\n",
        "        # shuffle once up front so that the first epoch is shuffled as well\n",
        "        self.on_epoch_end()\n",
        "\n",
        "    def agg_func(self, input):\n",
        "        \"\"\"\n",
        "        agg_func - function to group words/tags of sentences together\n",
        "        Returns a list of (word, POS, tag) tuples for the given rows\n",
        "        \"\"\"\n",
        "        return list(zip(input['Word'].values.tolist(),\n",
        "            input['POS'].values.tolist(), input['Tag'].values.tolist()))\n",
        "\n",
        "    def pad_sentences(self, input):\n",
        "        \"\"\"\n",
        "        pad_sentences - map words to indices and post-pad every sentence\n",
        "        to maxlen with 0\n",
        "        \"\"\"\n",
        "        input = [[self.word2dix[w[0]] for w in s] for s in input]\n",
        "        return pad_sequences(maxlen = self.maxlen, sequences = input, padding = 'post', value = 0)\n",
        "\n",
        "    def generate_labels(self, input):\n",
        "        \"\"\"\n",
        "        generate_labels - map tags to indices, post-pad with the 'O' tag\n",
        "        and one-hot encode every position over n_tags classes\n",
        "        \"\"\"\n",
        "        input = [[self.tag2dix[w[2]] for w in s] for s in input]\n",
        "        input = pad_sequences(maxlen = self.maxlen, sequences = input, padding = 'post', value = self.tag2dix['O'])\n",
        "        return np.array([to_categorical(x, num_classes = self.n_tags) for x in input])\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"\n",
        "        __len__ - number of batches per epoch\n",
        "        Rounds up so that a final, smaller batch is not dropped\n",
        "        \"\"\"\n",
        "        n_sentences = self.grouped.shape[0]\n",
        "        return (n_sentences + self.batch_size - 1) // self.batch_size\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        \"\"\"\n",
        "        __getitem__ - build batch number index\n",
        "        Outputs:\n",
        "            - sentences : padded word indices of the batch\n",
        "            - labels : one-hot encoded tags of the batch\n",
        "        \"\"\"\n",
        "        # positional slice; the last batch may hold fewer than batch_size sentences\n",
        "        batch = self.grouped.iloc[index * self.batch_size : (index + 1) * self.batch_size]\n",
        "\n",
        "        # generate sentences and labels\n",
        "        return self.pad_sentences(batch), self.generate_labels(batch)\n",
        "\n",
        "    def on_epoch_end(self):\n",
        "        \"\"\"\n",
        "        on_epoch_end - reshuffle the sentences between epochs when requested\n",
        "        Shuffling here instead of inside __getitem__ keeps the order fixed\n",
        "        within an epoch, whatever order the batches are requested in\n",
        "        \"\"\"\n",
        "        if self.shuffle:\n",
        "            self.grouped = shuffle(self.grouped).reset_index(drop = True)"
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P0dXUAzhj6fi",
        "colab_type": "text"
      },
      "source": [
        "### Import corpus"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0bHRIHgrj6fk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "data = pd.read_csv('ner_dataset.csv', encoding = 'latin1')\n",
        "# forward-fill missing values; ffill() replaces the deprecated fillna(method='ffill') spelling\n",
        "data = data.ffill()"
      ],
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZXCmi-p1j6fv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "outputId": "ce27b229-7740-4c87-f08a-d45f938c2406"
      },
      "source": [
        "data.head()"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Sentence #</th>\n",
              "      <th>Word</th>\n",
              "      <th>POS</th>\n",
              "      <th>Tag</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Sentence: 1</td>\n",
              "      <td>Thousands</td>\n",
              "      <td>NNS</td>\n",
              "      <td>O</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Sentence: 1</td>\n",
              "      <td>of</td>\n",
              "      <td>IN</td>\n",
              "      <td>O</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>Sentence: 1</td>\n",
              "      <td>demonstrators</td>\n",
              "      <td>NNS</td>\n",
              "      <td>O</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Sentence: 1</td>\n",
              "      <td>have</td>\n",
              "      <td>VBP</td>\n",
              "      <td>O</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>Sentence: 1</td>\n",
              "      <td>marched</td>\n",
              "      <td>VBN</td>\n",
              "      <td>O</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "    Sentence #           Word  POS Tag\n",
              "0  Sentence: 1      Thousands  NNS   O\n",
              "1  Sentence: 1             of   IN   O\n",
              "2  Sentence: 1  demonstrators  NNS   O\n",
              "3  Sentence: 1           have  VBP   O\n",
              "4  Sentence: 1        marched  VBN   O"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AgOyuf-kj6f7",
        "colab_type": "text"
      },
      "source": [
        "#### Build Dictionary of Words and Tags"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9jVoFMlmj6f9",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "d7599aab-6f6d-4f46-b7e7-a1bab29cb373"
      },
      "source": [
        "# build list of distinct words\n",
        "# sorted because set iteration order can change between kernel sessions, which would\n",
        "# silently change every word index; key=str keeps the sort safe for non-string values\n",
        "words = sorted(set(data[\"Word\"].values), key = str)\n",
        "words.append(\"ENDPAD\")\n",
        "n_words = len(words); n_words"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "35179"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AQ4cDFNcj6gH",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "5955b12d-f7d7-4b10-ba4e-fa33712edb0c"
      },
      "source": [
        "# build list of distinct tags\n",
        "# sorted because set iteration order can change between kernel sessions, which would\n",
        "# silently change every tag index\n",
        "tags = sorted(set(data[\"Tag\"].values))\n",
        "n_tags = len(tags); n_tags"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "17"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4HtwDFtcj6gQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# lookup tables from word / tag to integer index\n",
        "max_len = 75\n",
        "# word indices start at 1; 0 is the value used to pad sentences\n",
        "word2dix = {word : idx for idx, word in enumerate(words, start = 1)}\n",
        "tag2dix = {tag : idx for idx, tag in enumerate(tags)}"
      ],
      "execution_count": 21,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A0uOa7s0j6gY",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "23ea18ac-f2fe-4ea5-f295-5a0b7e115c0b"
      },
      "source": [
        "word2dix['Obama']"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "23870"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OYKx-9Thj6gh",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "654dccd3-a05b-420d-cd2b-5a2c84ae7b76"
      },
      "source": [
        "tag2dix['O']"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "7"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jhgqNwUbj6gq",
        "colab_type": "text"
      },
      "source": [
        "#### Generate Sentence Getter"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xu_7uJI7j6gs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_size = 32\n",
        "getter = SentenceGetter(data = data, words = words, tags = tags, maxlen = max_len, batch_size = batch_size)"
      ],
      "execution_count": 37,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B6h4SbeWj6g0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sentences = getter.sentences"
      ],
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w4sIsft7j6g-",
        "colab_type": "text"
      },
      "source": [
        "#### Tokenize and prepare sentences"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pv4eLC0Mj6hA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensorflow.keras.preprocessing.sequence import pad_sequences"
      ],
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "snED-JKmj6hG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X = [[word2dix[w[0]] for w in s] for s in sentences]\n",
        "X = pad_sequences(maxlen = max_len, sequences = X, padding = 'post', value = 0)"
      ],
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RQjV6Wpij6hQ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 434
        },
        "outputId": "7c980d72-e660-45ff-9904-fd546a38998c"
      },
      "source": [
        "sentences[0]"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[('Thousands', 'NNS', 'O'),\n",
              " ('of', 'IN', 'O'),\n",
              " ('demonstrators', 'NNS', 'O'),\n",
              " ('have', 'VBP', 'O'),\n",
              " ('marched', 'VBN', 'O'),\n",
              " ('through', 'IN', 'O'),\n",
              " ('London', 'NNP', 'B-geo'),\n",
              " ('to', 'TO', 'O'),\n",
              " ('protest', 'VB', 'O'),\n",
              " ('the', 'DT', 'O'),\n",
              " ('war', 'NN', 'O'),\n",
              " ('in', 'IN', 'O'),\n",
              " ('Iraq', 'NNP', 'B-geo'),\n",
              " ('and', 'CC', 'O'),\n",
              " ('demand', 'VB', 'O'),\n",
              " ('the', 'DT', 'O'),\n",
              " ('withdrawal', 'NN', 'O'),\n",
              " ('of', 'IN', 'O'),\n",
              " ('British', 'JJ', 'B-gpe'),\n",
              " ('troops', 'NNS', 'O'),\n",
              " ('from', 'IN', 'O'),\n",
              " ('that', 'DT', 'O'),\n",
              " ('country', 'NN', 'O'),\n",
              " ('.', '.', 'O')]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r45lHffEj6hX",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "3ea65342-5b97-40f6-d09b-50d1f460a76e"
      },
      "source": [
        "X[0].shape"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(75,)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PzwOHOdLj6hf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "y = [[tag2dix[w[2]] for w in s] for s in sentences]\n",
        "y = pad_sequences(maxlen=max_len, sequences=y, padding=\"post\", value=tag2dix[\"O\"])"
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tvYkYNw3j6hk",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        },
        "outputId": "398a9387-d081-4e97-9774-a2f5de3d72aa"
      },
      "source": [
        "y[1]"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([ 8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  4,  0,\n",
              "        0,  0, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
              "        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
              "        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
              "        0,  0,  0,  0,  0,  0,  0], dtype=int32)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "asG9UdIwj6hs",
        "colab_type": "text"
      },
      "source": [
        "#### Build BiLSTM-CRF model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vY8U3ySSj6hs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"\"\"\n",
        "output.py - module to store BiLSTM-CRF model\n",
        "\"\"\"\n",
        "\n",
        "from tensorflow.keras import Model\n",
        "from tensorflow.keras.layers import Input, LSTM, Lambda, Embedding, TimeDistributed, Dropout, Bidirectional, Dense, Layer, InputSpec\n",
        "import tensorflow_addons as tfa\n",
        "from tensorflow_addons.text import crf_log_likelihood, viterbi_decode, crf_decode\n",
        "import tensorflow.keras.backend as K\n",
        "\n",
        "def embedding_layer(input_dim, output_dim, input_length, mask_zero):\n",
        "    \"\"\"\n",
        "    embedding_layer - build the word-embedding layer\n",
        "    All four arguments are forwarded unchanged to tensorflow.keras.layers.Embedding\n",
        "    \"\"\"\n",
        "    return Embedding(input_dim = input_dim, output_dim = output_dim, input_length = input_length, mask_zero = mask_zero)\n",
        "\n",
        "def bilstm_crf(maxlen, n_tags, embedding_dim, n_words, mask_zero, training = True):\n",
        "    \"\"\"\n",
        "    bilstm_crf - module to build BiLSTM-CRF model\n",
        "    Inputs:\n",
        "        - maxlen : int\n",
        "            Number of tokens per padded sentence\n",
        "        - n_tags : int\n",
        "            Number of distinct tags, i.e. the width of the Dense and CRF layers\n",
        "        - embedding_dim : int\n",
        "            Size of the word-embedding vectors\n",
        "        - n_words : int\n",
        "            Number of distinct words; the embedding is built with\n",
        "            n_words + 1 rows so that index 0 stays free for padding\n",
        "        - mask_zero : bool\n",
        "            Forwarded to the embedding layer\n",
        "        - training : bool\n",
        "            Currently unused; kept so that existing callers keep working\n",
        "    Outputs:\n",
        "        - model : tensorflow.keras.Model\n",
        "            Embedding -> BiLSTM -> TimeDistributed(Dense) -> CRF model\n",
        "    \"\"\"\n",
        "    inputs = Input(shape = (maxlen,))\n",
        "    # Embedding layer\n",
        "    embeddings = embedding_layer(input_dim = n_words + 1, output_dim = embedding_dim, input_length = maxlen, mask_zero = mask_zero)\n",
        "    output = embeddings(inputs)\n",
        "\n",
        "    # BiLSTM layer\n",
        "    output = Bidirectional(LSTM(units = 50, return_sequences = True, recurrent_dropout = 0.1))(output)\n",
        "\n",
        "    # Dense layer\n",
        "    output = TimeDistributed(Dense(n_tags, activation = 'relu'))(output)\n",
        "\n",
        "    # the name 'crf_layer' lets callers look the layer up with model.get_layer\n",
        "    output = CRF(n_tags, name = 'crf_layer')(output)\n",
        "    return Model(inputs, output)\n",
        "\n",
        "\n",
        "class CRF(Layer):\n",
        "    \"\"\"\n",
        "    CRF output layer built on tensorflow_addons.text.\n",
        "    In the training phase call() returns its inputs unchanged; otherwise it\n",
        "    returns the one-hot encoded Viterbi decoding of the inputs.\n",
        "    The loss and accuracy properties return the functions to pass to\n",
        "    model.compile.\n",
        "    \"\"\"\n",
        "    def __init__(self,\n",
        "                 output_dim,\n",
        "                 sparse_target=True,\n",
        "                 **kwargs):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            output_dim (int): the number of labels to tag each temporal input.\n",
        "            sparse_target (bool): when True, y_true is reduced to tag indices with\n",
        "                argmax over its last axis (i.e. it is expected to be one-hot);\n",
        "                when False, y_true is used as given.\n",
        "        Input shape:\n",
        "            (batch_size, sentence length, output_dim)\n",
        "        Output shape:\n",
        "            (batch_size, sentence length, output_dim)\n",
        "        \"\"\"\n",
        "        super(CRF, self).__init__(**kwargs)\n",
        "        self.output_dim = int(output_dim)\n",
        "        self.sparse_target = sparse_target\n",
        "        self.input_spec = InputSpec(min_ndim=3)\n",
        "        self.supports_masking = False\n",
        "        # set on every call(); read later by the loss function\n",
        "        self.sequence_lengths = None\n",
        "        # transition weight matrix, created in build()\n",
        "        self.transitions = None\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        \"\"\"\n",
        "        Validate the input shape and create the trainable\n",
        "        (output_dim, output_dim) transition matrix.\n",
        "        \"\"\"\n",
        "        assert len(input_shape) == 3\n",
        "        f_shape = tf.TensorShape(input_shape)\n",
        "        input_spec = InputSpec(min_ndim=3, axes={-1: f_shape[-1]})\n",
        "\n",
        "        if f_shape[-1] is None:\n",
        "            raise ValueError('The last dimension of the inputs to `CRF` '\n",
        "                             'should be defined. Found `None`.')\n",
        "        if f_shape[-1] != self.output_dim:\n",
        "            raise ValueError('The last dimension of the input shape must be equal to output'\n",
        "                             ' shape. Use a linear layer if needed.')\n",
        "        self.input_spec = input_spec\n",
        "        self.transitions = self.add_weight(name='transitions',\n",
        "                                           shape=[self.output_dim, self.output_dim],\n",
        "                                           initializer='glorot_uniform',\n",
        "                                           trainable=True)\n",
        "        self.built = True\n",
        "\n",
        "    def compute_mask(self, inputs, mask=None):\n",
        "        # Just pass the received mask from previous layer, to the next layer or\n",
        "        # manipulate it if this layer changes the shape of the input\n",
        "        return mask\n",
        "\n",
        "    def call(self, inputs, sequence_lengths=None, training=None, **kwargs):\n",
        "        \"\"\"\n",
        "        Return the inputs in the training phase and the one-hot Viterbi\n",
        "        decoding otherwise.\n",
        "        sequence_lengths, when given, must be an int32 tensor of shape\n",
        "        (batch_size, 1); when omitted every sequence is treated as full length.\n",
        "        \"\"\"\n",
        "        sequences = tf.convert_to_tensor(inputs, dtype=self.dtype)\n",
        "        if sequence_lengths is not None:\n",
        "            assert len(sequence_lengths.shape) == 2\n",
        "            assert tf.convert_to_tensor(sequence_lengths).dtype == 'int32'\n",
        "            seq_len_shape = tf.convert_to_tensor(sequence_lengths).get_shape().as_list()\n",
        "            assert seq_len_shape[1] == 1\n",
        "            self.sequence_lengths = K.flatten(sequence_lengths)\n",
        "        else:\n",
        "            self.sequence_lengths = tf.ones(tf.shape(inputs)[0], dtype=tf.int32) * (\n",
        "                tf.shape(inputs)[1]\n",
        "            )\n",
        "\n",
        "        viterbi_sequence, _ = crf_decode(sequences,\n",
        "                                         self.transitions,\n",
        "                                         self.sequence_lengths)\n",
        "        output = K.one_hot(viterbi_sequence, self.output_dim)\n",
        "        return K.in_train_phase(sequences, output)\n",
        "\n",
        "    @property\n",
        "    def loss(self):\n",
        "        \"\"\"\n",
        "        Loss function: mean negative CRF log-likelihood.\n",
        "        It reads self.sequence_lengths, so the layer must have been called first.\n",
        "        \"\"\"\n",
        "        def crf_loss(y_true, y_pred):\n",
        "            y_pred = tf.convert_to_tensor(y_pred, dtype=self.dtype)\n",
        "            # The second return value is deliberately discarded: rebinding\n",
        "            # self.transitions to it could replace the layer's weight with a\n",
        "            # plain tensor.\n",
        "            log_likelihood, _ = crf_log_likelihood(\n",
        "                y_pred,\n",
        "                tf.cast(K.argmax(y_true), dtype=tf.int32) if self.sparse_target else y_true,\n",
        "                self.sequence_lengths,\n",
        "                transition_params=self.transitions,\n",
        "            )\n",
        "            return tf.reduce_mean(-log_likelihood)\n",
        "        return crf_loss\n",
        "\n",
        "    @property\n",
        "    def accuracy(self):\n",
        "        \"\"\"\n",
        "        Metric function: share of positions where the Viterbi decoding of\n",
        "        y_pred equals y_true.\n",
        "        \"\"\"\n",
        "        def viterbi_accuracy(y_true, y_pred):\n",
        "            # -1e10 to avoid zero at sum(mask)\n",
        "            mask = K.cast(\n",
        "                K.all(K.greater(y_pred, -1e10), axis=2), K.floatx())\n",
        "            shape = tf.shape(y_pred)\n",
        "            sequence_lengths = tf.ones(shape[0], dtype=tf.int32) * (shape[1])\n",
        "            y_pred, _ = crf_decode(y_pred, self.transitions, sequence_lengths)\n",
        "            if self.sparse_target:\n",
        "                y_true = K.argmax(y_true, 2)\n",
        "            y_pred = K.cast(y_pred, 'int32')\n",
        "            y_true = K.cast(y_true, 'int32')\n",
        "            corrects = K.cast(K.equal(y_true, y_pred), K.floatx())\n",
        "            return K.sum(corrects * mask) / K.sum(mask)\n",
        "        return viterbi_accuracy\n",
        "\n",
        "    def compute_output_shape(self, input_shape):\n",
        "        tf.TensorShape(input_shape).assert_has_rank(3)\n",
        "        return input_shape[:2] + (self.output_dim,)\n",
        "\n",
        "    def get_config(self):\n",
        "        \"\"\"\n",
        "        Return the constructor arguments of the layer.\n",
        "        Only values accepted by __init__ are included, because from_config\n",
        "        passes every entry back to the constructor. The transition matrix is\n",
        "        a layer weight created with add_weight, so it is not part of the config.\n",
        "        \"\"\"\n",
        "        config = {\n",
        "            'output_dim': self.output_dim,\n",
        "            'sparse_target': self.sparse_target,\n",
        "        }\n",
        "        base_config = super(CRF, self).get_config()\n",
        "        return dict(base_config, **config)"
      ],
      "execution_count": 38,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oid0TQNPj6h0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 382
        },
        "outputId": "e15e845c-f515-4300-fa47-a1c8cfb3535a"
      },
      "source": [
        "model = bilstm_crf(maxlen = max_len, n_tags = n_tags, embedding_dim = 20, n_words = n_words, mask_zero = True)\n",
        "model.summary()"
      ],
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Layer lstm will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
            "WARNING:tensorflow:Layer lstm will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
            "WARNING:tensorflow:Layer lstm will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU\n",
            "Model: \"functional_1\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "input_1 (InputLayer)         [(None, 75)]              0         \n",
            "_________________________________________________________________\n",
            "embedding (Embedding)        (None, 75, 20)            703600    \n",
            "_________________________________________________________________\n",
            "bidirectional (Bidirectional (None, 75, 100)           28400     \n",
            "_________________________________________________________________\n",
            "time_distributed (TimeDistri (None, 75, 17)            1717      \n",
            "_________________________________________________________________\n",
            "crf_layer (CRF)              (None, 75, 17)            289       \n",
            "=================================================================\n",
            "Total params: 734,006\n",
            "Trainable params: 734,006\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q_uT6UlioJJK",
        "colab_type": "text"
      },
      "source": [
        "### Configure and train model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OsLE42mmj6h8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensorflow.keras.optimizers import Adam\n",
        "\n",
        "# The last layer of the model provides both the training loss and the accuracy metric.\n",
        "# `metrics` is wrapped in a list, which is the form documented for `Model.compile`.\n",
        "model.compile(\n",
        "    optimizer=Adam(learning_rate=0.01),\n",
        "    loss=model.layers[-1].loss,\n",
        "    metrics=[model.layers[-1].accuracy],\n",
        ")"
      ],
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9IjQTAv3j6iF",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 159
        },
        "outputId": "160c334d-860b-4dc5-aa02-8f142397b0cb"
      },
      "source": [
        "from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping\n",
        "\n",
        "# Training callbacks; every monitoring callback watches the training loss.\n",
        "logging = TensorBoard(log_dir='./logs')\n",
        "# NOTE: Keras reports `period` as deprecated in favour of `save_freq` (counted in batches).\n",
        "# It is kept unchanged here so the existing checkpoint cadence is preserved.\n",
        "checkpoint = ModelCheckpoint(\n",
        "    'logs/v1/ep{epoch:03d}-oss{loss:.3f}.h5',\n",
        "    monitor='loss',\n",
        "    save_weights_only=True,\n",
        "    save_best_only=True,\n",
        "    period=3,\n",
        ")\n",
        "reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=5, verbose=1)\n",
        "early_stopping = EarlyStopping(monitor='loss', min_delta=0, patience=10, verbose=1)\n",
        "\n",
        "# Use a quarter of the generator's batches per epoch. Floor division keeps the step\n",
        "# count an integer, and max() guards against a generator with fewer than 4 batches.\n",
        "steps_per_epoch = max(1, len(getter) // 4)\n",
        "\n",
        "model.fit(\n",
        "    getter,\n",
        "    epochs=200,\n",
        "    initial_epoch=0,\n",
        "    steps_per_epoch=steps_per_epoch,\n",
        "    # early_stopping is registered so training can end before all 200 epochs\n",
        "    # once the monitored loss stops improving.\n",
        "    callbacks=[logging, checkpoint, reduce_lr, early_stopping],\n",
        "    verbose=1,\n",
        "    shuffle=True,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.\n",
            "Epoch 1/200\n",
            "  1/374 [..............................] - ETA: 0s - loss: 59.8860 - viterbi_accuracy: 8.3333e-04WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.\n",
            "Instructions for updating:\n",
            "use `tf.profiler.experimental.stop` instead.\n",
            "  2/374 [..............................] - ETA: 14:18 - loss: 56.2584 - viterbi_accuracy: 0.4521 WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 1.2035s vs `on_train_batch_end` time: 3.4067s). Check your callbacks.\n",
            "141/374 [==========>...................] - ETA: 3:05 - loss: 17.7238 - viterbi_accuracy: 0.9508"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sVCVDs7Kj6iL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Export the model to 'bilstm_crf.tf' using the 'tf' save format.\n",
        "tf.keras.models.save_model(\n",
        "    model,\n",
        "    filepath='bilstm_crf.tf',\n",
        "    save_format='tf',\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1dTq4PCTvcmA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}